"""
lakeFS Authentication Module

Authentication utility functions, including lakeFS login via an AWS IAM role.
"""

import base64
import datetime
import json
from typing import Any, Optional, TYPE_CHECKING
from urllib.parse import urlparse, parse_qs

from lakefs_sdk import ExternalLoginInformation
from lakefs_sdk.client import LakeFSClient

from lakefs.config import ClientConfig
from lakefs.exceptions import api_exception_handler

if TYPE_CHECKING:
    # Imported for type annotations only, so this module does not need boto3 at runtime.
    import boto3

# Signing region used when the STS endpoint URL does not name a region (e.g. the global sts.amazonaws.com).
DEFAULT_AWS_REGION = "us-east-1"

def access_token_from_aws_iam_role(sdk_client: LakeFSClient,
                                   lakefs_host: str,
                                   boto3_session: "boto3.Session",
                                   aws_provider_auth_params: ClientConfig.AWSIAMProviderConfig
                                   ) -> tuple[Any, datetime.datetime]:
    """
    Log in to lakeFS with an identity token derived from an AWS IAM role.

    A presigned STS ``GetCallerIdentity`` request is built from ``boto3_session`` and exchanged for a
    lakeFS access token through the external-principal login API.

    :param sdk_client: LakeFSClient used to call the lakeFS auth API
    :param lakefs_host: LakeFS API URL; used as the default ``X-LakeFS-Server-ID`` header of the signed STS request
    :param boto3_session: boto3 Session providing the AWS credentials
    :param aws_provider_auth_params: presigned-URL TTL, token TTL and optional request headers
    :return: a ``(token, reset_token_time)`` tuple - the lakeFS access token, and a timezone-aware UTC
        datetime 5 minutes before the requested token expiration, after which the token should be renewed.
    """
    presigned_ttl = aws_provider_auth_params.url_presign_ttl_seconds
    token_ttl_seconds = aws_provider_auth_params.token_ttl_seconds
    token_req_headers = aws_provider_auth_params.token_request_headers

    identity_token = _get_identity_token(boto3_session, lakefs_host, presign_expiry=presigned_ttl,
                                         additional_headers=token_req_headers)
    external_login_information = ExternalLoginInformation(
        token_expiration_duration=token_ttl_seconds,
        identity_request={
            "identity_token": identity_token
        }
    )

    # Timestamp taken *before* the login call. Assuming the server starts the token lifetime when it
    # handles the request, this keeps the computed expiration from landing after the real one.
    issued_at = datetime.datetime.now(datetime.timezone.utc)
    with api_exception_handler():
        auth_token = sdk_client.auth_api.external_principal_login(external_login_information)
    expiration_time = issued_at + datetime.timedelta(seconds=token_ttl_seconds)
    # NOTE(review): when token_ttl_seconds <= 300 the reset time is already in the past on return;
    # confirm callers do not end up re-logging in on every request in that case.
    reset_token_time = expiration_time - datetime.timedelta(minutes=5)
    return auth_token.token, reset_token_time

def _get_identity_token(
        session: "boto3.Session",
        lakefs_host: str,
        additional_headers: Optional[dict[str, str]],
        presign_expiry: int
) -> str:
    """
    Generate the identity token required for lakeFS authentication from an AWS session.

    A SigV4 presigned POST request for the STS ``GetCallerIdentity`` action is signed with the session
    credentials, and the values needed to verify it are packed into a base64-encoded JSON object.

    :param session: A boto3 session object with the AWS credentials to sign with.
    :param lakefs_host: lakeFS host, signed in as the ``X-LakeFS-Server-ID`` header when
        ``additional_headers`` is None.
    :param additional_headers: Headers to sign into the request. When given, they replace the default
        ``X-LakeFS-Server-ID`` header rather than extending it.
    :param presign_expiry: Validity of the presigned request, in seconds.
    :return: A base64-encoded JSON string containing the required authentication information.
    """

    # this method should only be called when installing the aws-iam additional requirement
    from botocore.client import Config  # pylint: disable=import-outside-toplevel, import-error
    from botocore.signers import RequestSigner  # pylint: disable=import-outside-toplevel, import-error

    sts_client = session.client('sts', config=Config(signature_version='v4'))
    # Drop any trailing slash so the signed path is exactly "/". The token carries only the host, so the
    # verifier presumably rebuilds the path as "/"; a "//" path would then fail signature verification.
    endpoint = sts_client.meta.endpoint_url.rstrip('/')
    service_id = sts_client.meta.service_model.service_id
    region = _extract_region_from_endpoint(endpoint)
    # A raw RequestSigner is used because the STS client's presigned URLs do not support additional headers.
    signer = RequestSigner(
        service_id,
        region,
        'sts',
        'v4',
        session.get_credentials(),
        session.events
    )
    if additional_headers is None:
        additional_headers = {
            'X-LakeFS-Server-ID': lakefs_host,
        }
    request_dict = {
        'method': 'POST',
        'url': f"{endpoint}/?Action=GetCallerIdentity&Version=2011-06-15",
        'body': {},
        'headers': additional_headers,
        'context': {}
    }
    presigned_url = signer.generate_presigned_url(
        request_dict,
        region_name=region,
        expires_in=presign_expiry,
        operation_name=''
    )
    return _encode_identity_token(presigned_url, region)

def _encode_identity_token(presigned_url: str, region: str) -> str:
    """
    Pack the parts of a presigned STS ``GetCallerIdentity`` URL into a lakeFS identity token.

    :param presigned_url: SigV4 presigned URL whose query string carries the ``X-Amz-*`` parameters.
    :param region: The signing region used to produce ``presigned_url``.
    :return: A base64-encoded JSON string with the method, host, region and signature details.
    :raises KeyError: If a required query parameter is missing from the URL.
    """
    parsed_url = urlparse(presigned_url)
    query_params = parse_qs(parsed_url.query)

    # Extract values from query parameters
    json_object = {
        "method": "POST",
        "host": parsed_url.hostname,
        "region": region,
        "action": query_params['Action'][0],
        "date": query_params['X-Amz-Date'][0],
        "expiration_duration": query_params['X-Amz-Expires'][0],
        # X-Amz-Credential is "<access key id>/<date>/<region>/<service>/aws4_request".
        "access_key_id": query_params['X-Amz-Credential'][0].split('/')[0],
        "signature": query_params['X-Amz-Signature'][0],
        "signed_headers": query_params.get('X-Amz-SignedHeaders', [''])[0].split(';'),
        "version": query_params['Version'][0],
        "algorithm": query_params['X-Amz-Algorithm'][0],
        # Present only for temporary (session) credentials.
        "security_token": query_params.get('X-Amz-Security-Token', [None])[0]
    }
    json_string = json.dumps(json_object)
    return base64.b64encode(json_string.encode('utf-8')).decode('utf-8')

def _extract_region_from_endpoint(endpoint):
    """
    Extract the region name from an STS endpoint URL.

    The first hostname label shaped like an AWS region is returned, which covers regional, FIPS,
    VPC-endpoint and non-commercial partition hosts, for example:
      https://sts.eu-central-1.amazonaws.com/                   -> eu-central-1
      https://sts.cn-north-1.amazonaws.com.cn                   -> cn-north-1
      https://vpce-0abc.sts.us-east-2.vpce.amazonaws.com        -> us-east-2
      https://sts.amazonaws.com/ (no region label)              -> DEFAULT_AWS_REGION

    :param endpoint: The endpoint URL of the STS client.
    :return: The region name extracted from the endpoint URL, or DEFAULT_AWS_REGION if none is found.
    """
    # Fall back to the raw string when it has no scheme (urlparse then yields no hostname).
    host = urlparse(endpoint).hostname or endpoint
    for label in host.split('.'):
        if _looks_like_aws_region(label):
            return label
    return DEFAULT_AWS_REGION

def _looks_like_aws_region(label):
    """
    Return True if a hostname label has the shape of an AWS region name.

    Matches "<2-letter code>-<word>[-<word>...]-<number>", e.g. ``us-east-1`` or ``us-gov-west-1``.
    """
    pieces = label.split('-')
    return (
        len(pieces) >= 3
        and len(pieces[0]) == 2
        and all(piece.isalpha() and piece.islower() for piece in pieces[:-1])
        and pieces[-1].isdigit()
    )
